// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package rowcodec_test

import (
	"math"
	"strings"
	"testing"
	"time"

	. "github.com/pingcap/check"
	"github.com/pingcap/tidb/parser/mysql"
	"github.com/pingcap/tidb/sessionctx/stmtctx"
	"github.com/pingcap/tidb/tablecodec"
	"github.com/pingcap/tidb/types"
	"github.com/pingcap/tidb/util/chunk"
	"github.com/pingcap/tidb/util/codec"
	"github.com/pingcap/tidb/util/rowcodec"
)

// TestT is the single `go test` entry point for this package; it hands
// control to gocheck, which runs every suite registered with Suite.
func TestT(t *testing.T) {
	TestingT(t)
}

// Register testSuite so gocheck discovers its Test* methods.
var _ = Suite(new(testSuite))

// testSuite is the stateless gocheck fixture for the rowcodec tests.
type testSuite struct{}

// testData describes one column that drives an encode/decode round trip.
//
// The test tables in this file may build it with positional literals, so the
// order of the fields below is significant.
type testData struct {
	// id is the column ID.
	id int64
	// ft is the column's field type.
	ft *types.FieldType
	// dt is the datum handed to the encoder; bt is the datum expected from
	// decoders that report the value in a different form (for example the
	// old-format byte decoder returning a float32 input as a float64).
	dt, bt types.Datum
	// def, when non-nil, is a default value: the column is left out of the
	// encoded row and the decoders are expected to supply it.
	def *types.Datum
	// handle marks the column as the integer primary-key handle.
	handle bool
}

// TestDecodeRowWithHandle encodes rows whose integer handle column is left out
// of the row value. It checks that the datum-map, chunk and old-format byte
// decoders all report the handle's value taken from the handle argument. The
// check runs for both a signed and an unsigned handle column.
func (s *testSuite) TestDecodeRowWithHandle(c *C) {
	handleID := int64(-1)
	handleValue := int64(10000)

	encodeAndDecodeHandle := func(c *C, testData []testData) {
		// transform test data into input; the handle column is described in
		// cols but never passed to the encoder.
		colIDs := make([]int64, 0, len(testData))
		dts := make([]types.Datum, 0, len(testData))
		fts := make([]*types.FieldType, 0, len(testData))
		cols := make([]rowcodec.ColInfo, 0, len(testData))
		for i := range testData {
			t := testData[i]
			if !t.handle {
				colIDs = append(colIDs, t.id)
				dts = append(dts, t.dt)
			}
			fts = append(fts, t.ft)
			cols = append(cols, rowcodec.ColInfo{
				ID:         t.id,
				Tp:         int32(t.ft.Tp),
				Flag:       int32(t.ft.Flag),
				IsPKHandle: t.handle,
				Flen:       t.ft.Flen,
				Decimal:    t.ft.Decimal,
				Elems:      t.ft.Elems,
			})
		}

		// test encode input.
		var encoder rowcodec.Encoder
		sc := new(stmtctx.StatementContext)
		sc.TimeZone = time.UTC
		newRow, err := encoder.Encode(sc, colIDs, dts, nil)
		c.Assert(err, IsNil)

		// decode to datum map; every column, including the handle, must be
		// present and equal to dt.
		mDecoder := rowcodec.NewDatumMapDecoder(cols, -1, sc.TimeZone)
		dm, err := mDecoder.DecodeToDatumMap(newRow, handleValue, nil)
		c.Assert(err, IsNil)
		for _, t := range testData {
			d, exists := dm[t.id]
			c.Assert(exists, IsTrue)
			c.Assert(d, DeepEquals, t.dt)
		}

		// decode to chunk; each column must read back as bt.
		cDecoder := rowcodec.NewChunkDecoder(cols, -1, nil, sc.TimeZone)
		chk := chunk.New(fts, 1, 1)
		err = cDecoder.DecodeToChunk(newRow, handleValue, chk)
		c.Assert(err, IsNil)
		chkRow := chk.GetRow(0)
		cdt := chkRow.GetDatumRow(fts)
		for i, t := range testData {
			c.Assert(cdt[i], DeepEquals, t.bt)
		}

		// decode to old row bytes; each column must decode to bt with no
		// trailing bytes.
		colOffset := make(map[int64]int, len(testData))
		for i, t := range testData {
			colOffset[t.id] = i
		}
		bDecoder := rowcodec.NewByteDecoder(cols, -1, nil, nil)
		oldRow, err := bDecoder.DecodeToBytes(colOffset, handleValue, newRow, nil)
		c.Assert(err, IsNil)
		for i, t := range testData {
			remain, d, err := codec.DecodeOne(oldRow[i])
			c.Assert(err, IsNil)
			c.Assert(len(remain), Equals, 0)
			c.Assert(d, DeepEquals, t.bt)
		}
	}

	// encode & decode signed int.
	testDataSigned := []testData{
		{
			id:     handleID,
			ft:     types.NewFieldType(mysql.TypeLonglong),
			dt:     types.NewIntDatum(handleValue),
			bt:     types.NewIntDatum(handleValue),
			handle: true,
		},
		{
			id: 10,
			ft: types.NewFieldType(mysql.TypeLonglong),
			dt: types.NewIntDatum(1),
			bt: types.NewIntDatum(1),
		},
	}
	encodeAndDecodeHandle(c, testDataSigned)

	// encode & decode unsigned int.
	testDataUnsigned := []testData{
		{
			id: handleID,
			ft: withUnsigned(types.NewFieldType(mysql.TypeLonglong)),
			// The datum-map decoder always yields the handle as a signed int...
			dt: types.NewIntDatum(handleValue),
			// ...while the chunk and byte decoders honor the unsigned flag.
			bt:     types.NewUintDatum(uint64(handleValue)),
			handle: true,
		},
		{
			id: 10,
			ft: types.NewFieldType(mysql.TypeLonglong),
			dt: types.NewIntDatum(1),
			bt: types.NewIntDatum(1),
		},
	}
	encodeAndDecodeHandle(c, testDataUnsigned)
}

// TestTypesNewRowCodec round-trips a row covering several column types
// through the encoder and the datum-map, chunk and old-format byte decoders.
// The types are signed and unsigned ints, double and float, strings, year and
// NULL. The round trip is repeated with a large column ID (300) and with a
// string value of math.MaxUint16+1 bytes.
func (s *testSuite) TestTypesNewRowCodec(c *C) {
	roundTrip := func(c *C, rows []testData) {
		n := len(rows)
		ids := make([]int64, 0, n)
		vals := make([]types.Datum, 0, n)
		fieldTypes := make([]*types.FieldType, 0, n)
		colInfos := make([]rowcodec.ColInfo, 0, n)
		for _, td := range rows {
			ids = append(ids, td.id)
			vals = append(vals, td.dt)
			fieldTypes = append(fieldTypes, td.ft)
			colInfos = append(colInfos, rowcodec.ColInfo{
				ID:         td.id,
				Tp:         int32(td.ft.Tp),
				Flag:       int32(td.ft.Flag),
				IsPKHandle: td.handle,
				Flen:       td.ft.Flen,
				Decimal:    td.ft.Decimal,
				Elems:      td.ft.Elems,
			})
		}

		sc := new(stmtctx.StatementContext)
		sc.TimeZone = time.UTC
		var encoder rowcodec.Encoder
		encoded, err := encoder.Encode(sc, ids, vals, nil)
		c.Assert(err, IsNil)

		// Datum map: every column is present and equals its input datum.
		mapDecoder := rowcodec.NewDatumMapDecoder(colInfos, -1, sc.TimeZone)
		byID, err := mapDecoder.DecodeToDatumMap(encoded, -1, nil)
		c.Assert(err, IsNil)
		for _, td := range rows {
			got, ok := byID[td.id]
			c.Assert(ok, IsTrue)
			c.Assert(got, DeepEquals, td.dt)
		}

		// Chunk: a single row read back as datums in column order.
		chunkDecoder := rowcodec.NewChunkDecoder(colInfos, -1, nil, sc.TimeZone)
		chk := chunk.New(fieldTypes, 1, 1)
		c.Assert(chunkDecoder.DecodeToChunk(encoded, -1, chk), IsNil)
		row := chk.GetRow(0)
		fromChunk := row.GetDatumRow(fieldTypes)
		for i := range rows {
			c.Assert(fromChunk[i], DeepEquals, rows[i].dt)
		}

		// Old-format bytes: each column decodes to bt with nothing left over.
		offsets := make(map[int64]int, n)
		for i, td := range rows {
			offsets[td.id] = i
		}
		byteDecoder := rowcodec.NewByteDecoder(colInfos, -1, nil, nil)
		legacy, err := byteDecoder.DecodeToBytes(offsets, -1, encoded, nil)
		c.Assert(err, IsNil)
		for i, td := range rows {
			rest, got, err := codec.DecodeOne(legacy[i])
			c.Assert(err, IsNil)
			c.Assert(len(rest), Equals, 0)
			c.Assert(got, DeepEquals, td.bt)
		}
	}

	cases := []testData{
		{id: 1, ft: types.NewFieldType(mysql.TypeLonglong), dt: types.NewIntDatum(1), bt: types.NewIntDatum(1)},
		{id: 22, ft: withUnsigned(types.NewFieldType(mysql.TypeShort)), dt: types.NewUintDatum(1), bt: types.NewUintDatum(1)},
		{id: 3, ft: types.NewFieldType(mysql.TypeDouble), dt: types.NewFloat64Datum(2), bt: types.NewFloat64Datum(2)},
		{id: 24, ft: types.NewFieldType(mysql.TypeString), dt: types.NewBytesDatum([]byte("abc")), bt: types.NewBytesDatum([]byte("abc"))},
		{id: 12, ft: types.NewFieldType(mysql.TypeYear), dt: types.NewIntDatum(1999), bt: types.NewIntDatum(1999)},
		{id: 11, ft: types.NewFieldType(mysql.TypeNull), dt: types.NewDatum(nil), bt: types.NewDatum(nil)},
		{id: 2, ft: types.NewFieldType(mysql.TypeNull), dt: types.NewDatum(nil), bt: types.NewDatum(nil)},
		{id: 100, ft: types.NewFieldType(mysql.TypeNull), dt: types.NewDatum(nil), bt: types.NewDatum(nil)},
		// a float32 input is expected back as a float64 from the byte decoder
		{id: 116, ft: types.NewFieldType(mysql.TypeFloat), dt: types.NewFloat32Datum(6), bt: types.NewFloat64Datum(6)},
		{id: 119, ft: types.NewFieldType(mysql.TypeVarString), dt: types.NewBytesDatum([]byte("")), bt: types.NewBytesDatum([]byte(""))},
	}

	// small column IDs and small values
	roundTrip(c, cases)

	// a large column ID; restored afterwards so later cases are unaffected
	cases[0].id = 300
	roundTrip(c, cases)
	cases[0].id = 1

	// a string longer than math.MaxUint16 bytes; dt and bt get separate copies
	long := strings.Repeat("a", math.MaxUint16+1)
	cases[3].dt = types.NewBytesDatum([]byte(long))
	cases[3].bt = types.NewBytesDatum([]byte(long))
	roundTrip(c, cases)
}

// TestNilAndDefault encodes a row that omits every column carrying a default
// value, then checks how each decoder treats the omitted columns. The
// datum-map decoder must leave them out. The chunk and byte decoders obtain
// them through their default-value callbacks and must report bt.
func (s *testSuite) TestNilAndDefault(c *C) {
	encodeAndDecode := func(c *C, testData []testData) {
		// transform test data into input; columns with a default are not
		// passed to the encoder.
		colIDs := make([]int64, 0, len(testData))
		dts := make([]types.Datum, 0, len(testData))
		cols := make([]rowcodec.ColInfo, 0, len(testData))
		fts := make([]*types.FieldType, 0, len(testData))
		for i := range testData {
			t := testData[i]
			if t.def == nil {
				colIDs = append(colIDs, t.id)
				dts = append(dts, t.dt)
			}
			fts = append(fts, t.ft)
			cols = append(cols, rowcodec.ColInfo{
				ID:         t.id,
				Tp:         int32(t.ft.Tp),
				Flag:       int32(t.ft.Flag),
				IsPKHandle: t.handle,
				Flen:       t.ft.Flen,
				Decimal:    t.ft.Decimal,
				Elems:      t.ft.Elems,
			})
		}
		// ddf is the chunk decoder's default-value callback: the column's
		// default datum if it has one, otherwise a NULL datum.
		ddf := func(i int) (types.Datum, error) {
			t := testData[i]
			if t.def == nil {
				var d types.Datum
				d.SetNull()
				return d, nil
			}
			return *t.def, nil
		}
		// bdf is the byte decoder's default-value callback: the default
		// encoded with tablecodec.EncodeValue, or nil when there is none.
		bdf := func(i int) ([]byte, error) {
			t := testData[i]
			if t.def == nil {
				return nil, nil
			}
			return getOldDatumByte(*t.def), nil
		}
		// test encode input.
		var encoder rowcodec.Encoder
		sc := new(stmtctx.StatementContext)
		sc.TimeZone = time.UTC
		newRow, err := encoder.Encode(sc, colIDs, dts, nil)
		c.Assert(err, IsNil)

		// decode to datum map.
		mDecoder := rowcodec.NewDatumMapDecoder(cols, -1, sc.TimeZone)
		dm, err := mDecoder.DecodeToDatumMap(newRow, -1, nil)
		c.Assert(err, IsNil)
		for _, t := range testData {
			d, exists := dm[t.id]
			if t.def != nil {
				// the datum-map decoder must not fill in default values.
				c.Assert(exists, IsFalse)
				continue
			}
			// encoded columns come back exactly as they were handed to the
			// encoder, matching the datum-map checks in the other tests.
			c.Assert(exists, IsTrue)
			c.Assert(d, DeepEquals, t.dt)
		}

		// decode to chunk; omitted columns are filled through ddf.
		chk := chunk.New(fts, 1, 1)
		cDecoder := rowcodec.NewChunkDecoder(cols, -1, ddf, sc.TimeZone)
		err = cDecoder.DecodeToChunk(newRow, -1, chk)
		c.Assert(err, IsNil)
		chkRow := chk.GetRow(0)
		cdt := chkRow.GetDatumRow(fts)
		for i, t := range testData {
			c.Assert(cdt[i], DeepEquals, t.bt)
		}

		// decode to old row bytes; omitted columns are filled through bdf.
		colOffset := make(map[int64]int, len(testData))
		for i, t := range testData {
			colOffset[t.id] = i
		}
		bDecoder := rowcodec.NewByteDecoder(cols, -1, bdf, sc.TimeZone)
		oldRow, err := bDecoder.DecodeToBytes(colOffset, -1, newRow, nil)
		c.Assert(err, IsNil)
		for i, t := range testData {
			remain, d, err := codec.DecodeOne(oldRow[i])
			c.Assert(err, IsNil)
			c.Assert(len(remain), Equals, 0)
			c.Assert(d, DeepEquals, t.bt)
		}
	}
	dtNilData := []testData{
		{
			id: 1,
			ft: types.NewFieldType(mysql.TypeLonglong),
			dt: types.NewIntDatum(1),
			bt: types.NewIntDatum(1),
		},
		{
			id: 2,
			ft: withUnsigned(types.NewFieldType(mysql.TypeLonglong)),
			// dt is never encoded because the column has a default; the
			// decoders that fill defaults must report 9.
			dt:  types.NewUintDatum(1),
			bt:  types.NewUintDatum(9),
			def: getDatumPoint(types.NewUintDatum(9)),
		},
	}
	encodeAndDecode(c, dtNilData)
}

// TestVarintCompatibility checks that integer columns converted from the new
// row format back to old-format bytes are byte-for-byte identical to what
// tablecodec.EncodeValue produces for the same datum. Both a signed and an
// unsigned column are covered.
func (s *testSuite) TestVarintCompatibility(c *C) {
	encodeAndDecodeByte := func(c *C, testData []testData) {
		// transform test data into input.
		colIDs := make([]int64, 0, len(testData))
		dts := make([]types.Datum, 0, len(testData))
		cols := make([]rowcodec.ColInfo, 0, len(testData))
		for i := range testData {
			t := testData[i]
			colIDs = append(colIDs, t.id)
			dts = append(dts, t.dt)
			cols = append(cols, rowcodec.ColInfo{
				ID:         t.id,
				Tp:         int32(t.ft.Tp),
				Flag:       int32(t.ft.Flag),
				IsPKHandle: t.handle,
				Flen:       t.ft.Flen,
				Decimal:    t.ft.Decimal,
				Elems:      t.ft.Elems,
			})
		}

		// test encode input.
		var encoder rowcodec.Encoder
		sc := new(stmtctx.StatementContext)
		sc.TimeZone = time.UTC
		newRow, err := encoder.Encode(sc, colIDs, dts, nil)
		c.Assert(err, IsNil)
		decoder := rowcodec.NewByteDecoder(cols, -1, nil, sc.TimeZone)
		// decode to old row bytes.
		colOffset := make(map[int64]int, len(testData))
		for i, t := range testData {
			colOffset[t.id] = i
		}
		// NOTE(review): no column here is a PK handle, so the handle argument
		// (1) is presumably irrelevant to the output — confirm.
		oldRow, err := decoder.DecodeToBytes(colOffset, 1, newRow, nil)
		c.Assert(err, IsNil)
		for i, t := range testData {
			expected, err := tablecodec.EncodeValue(nil, nil, t.bt) // tablecodec will encode as varint/varuint
			c.Assert(err, IsNil)
			// gocheck convention: obtained value first, expected value last.
			c.Assert(oldRow[i], DeepEquals, expected)
		}
	}

	testDataValue := []testData{
		{
			id: 1,
			ft: types.NewFieldType(mysql.TypeLonglong),
			dt: types.NewIntDatum(1),
			bt: types.NewIntDatum(1),
		},
		{
			id: 2,
			ft: withUnsigned(types.NewFieldType(mysql.TypeLonglong)),
			dt: types.NewUintDatum(1),
			bt: types.NewUintDatum(1),
		},
	}
	encodeAndDecodeByte(c, testDataValue)
}

// TestCodecUtil converts a row produced by tablecodec.EncodeRow into the new
// row format with EncodeFromOldRow. It then exercises Decoder.ColumnIsNull
// and IsRowKey.
func (s *testSuite) TestCodecUtil(c *C) {
	colIDs := []int64{1, 2, 3, 4}
	tps := []*types.FieldType{
		types.NewFieldType(mysql.TypeLonglong),
		types.NewFieldType(mysql.TypeLonglong),
		types.NewFieldType(mysql.TypeLonglong),
		types.NewFieldType(mysql.TypeNull),
	}
	sc := new(stmtctx.StatementContext)
	rd := &rowcodec.Encoder{}
	oldRow, err := tablecodec.EncodeRow(sc, types.MakeDatums(1, 2, 3, nil), colIDs, nil, nil, rd)
	// Assert (not Check): converting a row that failed to encode is meaningless.
	c.Assert(err, IsNil)
	var rb rowcodec.Encoder
	newRow, err := rowcodec.EncodeFromOldRow(&rb, nil, oldRow, nil)
	c.Assert(err, IsNil)
	c.Assert(rowcodec.IsNewFormat(newRow), IsTrue)

	// build a decoder over the same four columns; none is the PK handle.
	cols := make([]rowcodec.ColInfo, 0, len(tps))
	for i, ft := range tps {
		cols = append(cols, rowcodec.ColInfo{
			ID:         colIDs[i],
			Tp:         int32(ft.Tp),
			Flag:       int32(ft.Flag),
			IsPKHandle: false,
			Flen:       ft.Flen,
			Decimal:    ft.Decimal,
			Elems:      ft.Elems,
		})
	}
	d := rowcodec.NewDecoder(cols, -1, nil)

	// test ColumnIsNull
	// column 4 was encoded as NULL.
	isNil, err := d.ColumnIsNull(newRow, 4, nil)
	c.Assert(err, IsNil)
	c.Assert(isNil, IsTrue)
	// column 1 holds a value.
	isNil, err = d.ColumnIsNull(newRow, 1, nil)
	c.Assert(err, IsNil)
	c.Assert(isNil, IsFalse)
	// column 5 is not in the row: NULL when no default is given...
	isNil, err = d.ColumnIsNull(newRow, 5, nil)
	c.Assert(err, IsNil)
	c.Assert(isNil, IsTrue)
	// ...and not NULL when a non-nil default is supplied.
	isNil, err = d.ColumnIsNull(newRow, 5, []byte{1})
	c.Assert(err, IsNil)
	c.Assert(isNil, IsFalse)

	// test isRowKey: neither of these two-byte keys is accepted.
	c.Assert(rowcodec.IsRowKey([]byte{'b', 't'}), IsFalse)
	c.Assert(rowcodec.IsRowKey([]byte{'t', 'r'}), IsFalse)
}

var (
	// withUnsigned sets mysql.UnsignedFlag on ft in place and returns the same
	// pointer, so it can wrap a types.NewFieldType call inline.
	withUnsigned = func(ft *types.FieldType) *types.FieldType {
		ft.Flag |= mysql.UnsignedFlag
		return ft
	}
	// getOldDatumByte returns d encoded with tablecodec.EncodeValue. It panics
	// if encoding fails, which is acceptable for test-only fixture code.
	getOldDatumByte = func(d types.Datum) []byte {
		encoded, err := tablecodec.EncodeValue(nil, nil, d)
		if err == nil {
			return encoded
		}
		panic(err)
	}
	// getDatumPoint returns a pointer to a fresh copy of d.
	getDatumPoint = func(d types.Datum) *types.Datum {
		p := new(types.Datum)
		*p = d
		return p
	}
)
